"""
Author: Shaoyu Liu
Date: 2023/03/17
Aggregate PatentSBERTa embeddings into batched NumPy arrays so that patent similarity can be computed by matrix multiplication.

example use:
python3 gender_postprocess_embed.py 

inputs:
    embedding JSONL files: patentsberta/embedding_jsonl/patent_embeddings_patentsberta_{i}.jsonl
    outputs (by default written under patentsberta/; the patent_id CSVs go in patentsberta/patent_id/):
    embedding_0_50.npz.npy
    embedding_100_150.npz.npy
    embedding_150_200.npz.npy
    embedding_200_250.npz.npy
    embedding_250_291.npz.npy
    embedding_50_100.npz.npy

    patent_id_0_50.csv
    patent_id_100_150.csv
    patent_id_150_200.csv
    patent_id_200_250.csv
    patent_id_250_291.csv
    patent_id_50_100.csv
"""

import os
import logging

# All relative input/output paths used below are resolved against this data
# root. NOTE(review): the chdir runs at import time and the path is
# machine-specific.
_DATA_ROOT = "/Volumes/Zihao_SSD2/PatentsView/"
os.chdir(_DATA_ROOT)


def aggregate_wrapper(n_files=291, batch_size=50, out_dir="patentsberta"):
    """
    Aggregate and save all PatentSBERTa embedding JSONL files, batch by batch.

    File indices 0 .. n_files - 1 are split into consecutive half-open ranges
    [start, end) of ``batch_size`` files (the last range may be shorter), and
    each range is passed to ``aggregate_files(..., save=True)``.

    n_files: total number of embedding JSONL files
    batch_size: number of files per output batch; must be positive
    out_dir: output directory forwarded to ``aggregate_files``

    Raises ValueError if ``batch_size`` is not positive.
    """
    # Look the logger up here instead of relying on a module global that only
    # exists when this file is run as a script.
    logger = logging.getLogger(__name__)
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    for idx_start in range(0, n_files, batch_size):
        # Batch width follows batch_size; the final batch is clamped to n_files.
        idx_end = min(idx_start + batch_size, n_files)
        logger.info(f"start batch: {idx_start}, end batch: {idx_end}")
        aggregate_files(
            idx_start=idx_start,
            idx_end=idx_end,
            out_dir=out_dir,
            save=True,
        )
    logger.info("done.")


def aggregate_files(
    idx_start,
    idx_end,
    out_dir="patentsberta",
    save=False,
    in_dir="patentsberta/embedding_jsonl",
):
    """
    Aggregate a range of PatentSBERTa embedding JSONL files into one table.

    Reads ``{in_dir}/patent_embeddings_patentsberta_{i}.jsonl`` for every
    ``i`` in ``range(idx_start, idx_end)``. Each non-blank line is parsed as a
    JSON object and its ``"patent_id"`` and ``"embedding"`` fields are kept.

    idx_start: first file index to read (inclusive)
    idx_end: last file index to read (exclusive)
    out_dir: output directory; only touched when ``save`` is True
    save: if True, write the outputs to disk and return None; otherwise
        return them without writing anything
    in_dir: directory containing the embedding JSONL files

    Returns (when ``save`` is False) a tuple ``(df, embeddings)``:
        df: DataFrame with a single ``patent_id`` column, in file/line order
        embeddings: float64 array whose row k belongs to row k of ``df``

    When ``save`` is True, writes
        ``{out_dir}/patent_id/patent_id_{idx_start}_{idx_end}.csv`` and
        ``{out_dir}/embedding_{idx_start}_{idx_end}.npz.npy``.
    """
    import json

    import numpy as np
    import pandas as pd

    try:
        from tqdm import tqdm as progress
    except ImportError:  # the progress bar is cosmetic; don't require tqdm
        def progress(iterable):
            return iterable

    logger = logging.getLogger(__name__)

    # Keep only the two fields we need instead of every parsed record dict.
    patent_ids = []
    embeddings = []
    for i in progress(range(idx_start, idx_end)):
        path = os.path.join(in_dir, f"patent_embeddings_patentsberta_{i}.jsonl")
        # `with` closes each file promptly instead of leaking the handle.
        with open(path, "r", encoding="utf-8") as fh:
            for line in fh:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a doubled trailing newline)
                record = json.loads(line)
                patent_ids.append(record["patent_id"])
                embeddings.append(record["embedding"])

    df = pd.DataFrame(data=patent_ids, columns=["patent_id"])
    ls_embedding = np.array(embeddings, dtype="float64")

    if not save:
        return df, ls_embedding

    path_df = os.path.join(out_dir, "patent_id", f"patent_id_{idx_start}_{idx_end}.csv")
    # np.save appends ".npy", so this file lands as embedding_{..}.npz.npy
    # (a plain .npy array, not an .npz archive), matching the output names
    # listed in the module docstring.
    path_embedding = os.path.join(out_dir, f"embedding_{idx_start}_{idx_end}.npz")

    if not os.path.isdir(out_dir):
        logger.info(f"directory not found, making directory {out_dir}")
    # Also create the patent_id/ subdirectory: to_csv does not create parents.
    os.makedirs(os.path.dirname(path_df), exist_ok=True)

    df.to_csv(path_df, index=False)
    np.save(path_embedding, ls_embedding)
    logger.info("data saved!")


def main():
    """Run the batch aggregation over every embedding file with default settings."""
    aggregate_wrapper()


if __name__ == "__main__":
    # Configure logging only when executed as a script, then bind the
    # module-level `logger` name before running the pipeline.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s:%(levelname)s:%(message)s",
    )
    logger = logging.getLogger(__name__)
    main()
